{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "baseline思路：使用CNN进行定长字符分类；\n",
    "\n",
    "运行系统要求：Python2/3，内存4G，有无GPU都可以"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-09T14:23:32.395260Z",
     "start_time": "2020-05-09T14:23:31.939967Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "import os, sys, glob, shutil, json\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n",
    "import cv2\n",
    "\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm, tqdm_notebook\n",
    "\n",
    "%pylab inline\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义读取数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-09T14:23:34.383443Z",
     "start_time": "2020-05-09T14:23:34.377373Z"
    }
   },
   "outputs": [],
   "source": [
    "class SVHNDataset(Dataset):\n",
    "    def __init__(self, img_path, img_label, transform=None):\n",
    "        self.img_path = img_path\n",
    "        self.img_label = img_label \n",
    "        if transform is not None:\n",
    "            self.transform = transform\n",
    "        else:\n",
    "            self.transform = None\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(self.img_path[index]).convert('RGB')\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        \n",
    "        lbl = np.array(self.img_label[index], dtype=np.int)\n",
    "        lbl = list(lbl)  + (5 - len(lbl)) * [10]\n",
    "        return img, torch.from_numpy(np.array(lbl[:5]))\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义读取数据dataloader\n",
    "\n",
    "假设数据存放在`../input`文件夹下，并进行解压。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-09T14:23:37.248125Z",
     "start_time": "2020-05-09T14:23:36.943592Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30000 30000\n",
      "10000 10000\n"
     ]
    }
   ],
   "source": [
    "train_path = glob.glob('../input/train/*.png')\n",
    "train_path.sort()\n",
    "train_json = json.load(open('../input/train.json'))\n",
    "train_label = [train_json[x]['label'] for x in train_json]\n",
    "print(len(train_path), len(train_label))\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    SVHNDataset(train_path, train_label,\n",
    "                transforms.Compose([\n",
    "                    transforms.Resize((64, 128)),\n",
    "                    transforms.RandomCrop((60, 120)),\n",
    "                    transforms.ColorJitter(0.3, 0.3, 0.2),\n",
    "                    transforms.RandomRotation(10),\n",
    "                    transforms.ToTensor(),\n",
    "                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "    ])), \n",
    "    batch_size=40, \n",
    "    shuffle=True, \n",
    "    num_workers=10,\n",
    ")\n",
    "\n",
    "val_path = glob.glob('../input/val/*.png')\n",
    "val_path.sort()\n",
    "val_json = json.load(open('../input/val.json'))\n",
    "val_label = [val_json[x]['label'] for x in val_json]\n",
    "print(len(val_path), len(val_label))\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    SVHNDataset(val_path, val_label,\n",
    "                transforms.Compose([\n",
    "                    transforms.Resize((60, 120)),\n",
    "                    # transforms.ColorJitter(0.3, 0.3, 0.2),\n",
    "                    # transforms.RandomRotation(5),\n",
    "                    transforms.ToTensor(),\n",
    "                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "    ])), \n",
    "    batch_size=40, \n",
    "    shuffle=False, \n",
    "    num_workers=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 定义分类模型\n",
    "\n",
    "这里使用ResNet18的模型进行特征提取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-09T14:23:38.370681Z",
     "start_time": "2020-05-09T14:23:38.359476Z"
    }
   },
   "outputs": [],
   "source": [
    "class SVHN_Model1(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SVHN_Model1, self).__init__()\n",
    "                \n",
    "        model_conv = models.resnet18(pretrained=True)\n",
    "        model_conv.avgpool = nn.AdaptiveAvgPool2d(1)\n",
    "        model_conv = nn.Sequential(*list(model_conv.children())[:-1])\n",
    "        self.cnn = model_conv\n",
    "        \n",
    "        self.fc1 = nn.Linear(512, 11)\n",
    "        self.fc2 = nn.Linear(512, 11)\n",
    "        self.fc3 = nn.Linear(512, 11)\n",
    "        self.fc4 = nn.Linear(512, 11)\n",
    "        self.fc5 = nn.Linear(512, 11)\n",
    "    \n",
    "    def forward(self, img):        \n",
    "        feat = self.cnn(img)\n",
    "        # print(feat.shape)\n",
    "        feat = feat.view(feat.shape[0], -1)\n",
    "        c1 = self.fc1(feat)\n",
    "        c2 = self.fc2(feat)\n",
    "        c3 = self.fc3(feat)\n",
    "        c4 = self.fc4(feat)\n",
    "        c5 = self.fc5(feat)\n",
    "        return c1, c2, c3, c4, c5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-09T14:23:39.461245Z",
     "start_time": "2020-05-09T14:23:39.445117Z"
    }
   },
   "outputs": [],
   "source": [
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    # 切换模型为训练模式\n",
    "    model.train()\n",
    "    train_loss = []\n",
    "    \n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        if use_cuda:\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "            \n",
    "        c0, c1, c2, c3, c4 = model(input)\n",
    "        loss = criterion(c0, target[:, 0]) + \\\n",
    "                criterion(c1, target[:, 1]) + \\\n",
    "                criterion(c2, target[:, 2]) + \\\n",
    "                criterion(c3, target[:, 3]) + \\\n",
    "                criterion(c4, target[:, 4])\n",
    "        \n",
    "        # loss /= 6\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        train_loss.append(loss.item())\n",
    "    return np.mean(train_loss)\n",
    "\n",
    "def validate(val_loader, model, criterion):\n",
    "    # 切换模型为预测模型\n",
    "    model.eval()\n",
    "    val_loss = []\n",
    "\n",
    "    # 不记录模型梯度信息\n",
    "    with torch.no_grad():\n",
    "        for i, (input, target) in enumerate(val_loader):\n",
    "            if use_cuda:\n",
    "                input = input.cuda()\n",
    "                target = target.cuda()\n",
    "            \n",
    "            c0, c1, c2, c3, c4 = model(input)\n",
    "            loss = criterion(c0, target[:, 0]) + \\\n",
    "                    criterion(c1, target[:, 1]) + \\\n",
    "                    criterion(c2, target[:, 2]) + \\\n",
    "                    criterion(c3, target[:, 3]) + \\\n",
    "                    criterion(c4, target[:, 4])\n",
    "            # loss /= 6\n",
    "            val_loss.append(loss.item())\n",
    "    return np.mean(val_loss)\n",
    "\n",
    "def predict(test_loader, model, tta=10):\n",
    "    model.eval()\n",
    "    test_pred_tta = None\n",
    "    \n",
    "    # TTA 次数\n",
    "    for _ in range(tta):\n",
    "        test_pred = []\n",
    "    \n",
    "        with torch.no_grad():\n",
    "            for i, (input, target) in enumerate(test_loader):\n",
    "                if use_cuda:\n",
    "                    input = input.cuda()\n",
    "                \n",
    "                c0, c1, c2, c3, c4 = model(input)\n",
    "                if use_cuda:\n",
    "                    output = np.concatenate([\n",
    "                        c0.data.cpu().numpy(), \n",
    "                        c1.data.cpu().numpy(),\n",
    "                        c2.data.cpu().numpy(), \n",
    "                        c3.data.cpu().numpy(),\n",
    "                        c4.data.cpu().numpy()], axis=1)\n",
    "                else:\n",
    "                    output = np.concatenate([\n",
    "                        c0.data.numpy(), \n",
    "                        c1.data.numpy(),\n",
    "                        c2.data.numpy(), \n",
    "                        c3.data.numpy(),\n",
    "                        c4.data.numpy()], axis=1)\n",
    "                \n",
    "                test_pred.append(output)\n",
    "        \n",
    "        test_pred = np.vstack(test_pred)\n",
    "        if test_pred_tta is None:\n",
    "            test_pred_tta = test_pred\n",
    "        else:\n",
    "            test_pred_tta += test_pred\n",
    "    \n",
    "    return test_pred_tta"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练与验证"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-09T14:27:06.642180Z",
     "start_time": "2020-05-09T14:23:50.533281Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Train loss: 3.549230361620585 \t Val loss: 3.4495382709503173\n",
      "0.3484\n",
      "Epoch: 1, Train loss: 2.262838746547699 \t Val loss: 3.070304814338684\n",
      "0.4262\n",
      "Epoch: 2, Train loss: 1.9043410422801972 \t Val loss: 2.8098975682258605\n",
      "0.4742\n",
      "Epoch: 3, Train loss: 1.6968964731693268 \t Val loss: 2.6710116691589354\n",
      "0.4948\n",
      "Epoch: 4, Train loss: 1.5359156634807587 \t Val loss: 2.652583309173584\n",
      "0.5126\n",
      "Epoch: 5, Train loss: 1.4319444489479065 \t Val loss: 2.630811668395996\n",
      "0.5247\n",
      "Epoch: 6, Train loss: 1.3359439101219177 \t Val loss: 2.477479434013367\n",
      "0.5383\n",
      "Epoch: 7, Train loss: 1.244316361983617 \t Val loss: 2.458457417488098\n",
      "0.5498\n",
      "Epoch: 8, Train loss: 1.182295233209928 \t Val loss: 2.5932382040023803\n",
      "0.5316\n",
      "Epoch: 9, Train loss: 1.1050671869913737 \t Val loss: 2.4330560541152955\n",
      "0.554\n"
     ]
    }
   ],
   "source": [
    "model = SVHN_Model1()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters(), 0.001)\n",
    "best_loss = 1000.0\n",
    "\n",
    "use_cuda = True\n",
    "if use_cuda:\n",
    "    model = model.cuda()\n",
    "\n",
    "for epoch in range(10):\n",
    "    train_loss = train(train_loader, model, criterion, optimizer, epoch)\n",
    "    val_loss = validate(val_loader, model, criterion)\n",
    "    \n",
    "    val_label = [''.join(map(str, x)) for x in val_loader.dataset.img_label]\n",
    "    val_predict_label = predict(val_loader, model, 1)\n",
    "    val_predict_label = np.vstack([\n",
    "        val_predict_label[:, :11].argmax(1),\n",
    "        val_predict_label[:, 11:22].argmax(1),\n",
    "        val_predict_label[:, 22:33].argmax(1),\n",
    "        val_predict_label[:, 33:44].argmax(1),\n",
    "        val_predict_label[:, 44:55].argmax(1),\n",
    "    ]).T\n",
    "    val_label_pred = []\n",
    "    for x in val_predict_label:\n",
    "        val_label_pred.append(''.join(map(str, x[x!=10])))\n",
    "    \n",
    "    val_char_acc = np.mean(np.array(val_label_pred) == np.array(val_label))\n",
    "    \n",
    "    print('Epoch: {0}, Train loss: {1} \\t Val loss: {2}'.format(epoch, train_loss, val_loss))\n",
    "    print('Val Acc', val_char_acc)\n",
    "    # 记录下验证集精度\n",
    "    if val_loss < best_loss:\n",
    "        best_loss = val_loss\n",
    "        # print('Find better model in Epoch {0}, saving model.'.format(epoch))\n",
    "        torch.save(model.state_dict(), './model.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预测并生成提交文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-09T14:27:46.875945Z",
     "start_time": "2020-05-09T14:27:46.575526Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "40000 40000\n"
     ]
    }
   ],
   "source": [
    "test_path = glob.glob('../input/test_a/*.png')\n",
    "test_path.sort()\n",
    "test_json = json.load(open('../input/test_a.json'))\n",
    "test_label = [[1]] * len(test_path)\n",
    "print(len(test_path), len(test_label))\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    SVHNDataset(test_path, test_label,\n",
    "                transforms.Compose([\n",
    "                    transforms.Resize((70, 140)),\n",
    "                    # transforms.RandomCrop((60, 120)),\n",
    "                    # transforms.ColorJitter(0.3, 0.3, 0.2),\n",
    "                    # transforms.RandomRotation(5),\n",
    "                    transforms.ToTensor(),\n",
    "                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "    ])), \n",
    "    batch_size=40, \n",
    "    shuffle=False, \n",
    "    num_workers=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-09T14:27:57.864970Z",
     "start_time": "2020-05-09T14:27:48.691924Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(40000, 55)\n"
     ]
    }
   ],
   "source": [
    "# 加载保存的最优模型\n",
    "model.load_state_dict(torch.load('model.pt'))\n",
    "\n",
    "test_predict_label = predict(test_loader, model, 1)\n",
    "print(test_predict_label.shape)\n",
    "\n",
    "test_label = [''.join(map(str, x)) for x in test_loader.dataset.img_label]\n",
    "test_predict_label = np.vstack([\n",
    "    test_predict_label[:, :11].argmax(1),\n",
    "    test_predict_label[:, 11:22].argmax(1),\n",
    "    test_predict_label[:, 22:33].argmax(1),\n",
    "    test_predict_label[:, 33:44].argmax(1),\n",
    "    test_predict_label[:, 44:55].argmax(1),\n",
    "]).T\n",
    "\n",
    "test_label_pred = []\n",
    "for x in test_predict_label:\n",
    "    test_label_pred.append(''.join(map(str, x[x!=10])))\n",
    "    \n",
    "import pandas as pd\n",
    "df_submit = pd.read_csv('../input/test_A_sample_submit.csv')\n",
    "df_submit['file_code'] = test_label_pred\n",
    "df_submit.to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
